MCMC: Convergence and Example


Lecture 15

March 20, 2024

Last Class(es)

Markov Chain Strategy

  • Generate an appropriate Markov chain so that its stationary distribution is the target distribution \(\pi\);
  • Run its dynamics long enough to converge to the stationary distribution;
  • Use the resulting ensemble of states as Monte Carlo samples from \(\pi\).

Markov Chain Convergence

Given a Markov chain \(\{X_t\}_{t=1, \ldots, T}\) generated by this procedure with target distribution \(\pi\):

  • \(\mathbb{P}(X_t = y) \to \pi(y)\) as \(t \to \infty\)
  • This means the chain can be considered a dependent sample approximately distributed from \(\pi\).
  • The first values (the transient portion) of the chain are highly dependent on the initial value.

The Metropolis-Hastings Algorithm

Given \(X_t = x_t\):

  1. Generate \(Y_t \sim q(y | x_t)\);
  2. Set \(X_{t+1} = Y_t\) with probability \(\rho(x_t, Y_t)\), where \[\rho(x, y) = \min \left\{\frac{\pi(y)}{\pi(x)}\frac{q(x | y)}{q(y | x)}, 1\right\},\] else set \(X_{t+1} = x_t\).
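The two steps above can be sketched in a few lines. This is a minimal random-walk Metropolis implementation (a sketch, not code from the lecture): the target is a standard normal given by its log-density, and with a symmetric Gaussian proposal \(q(x | y) = q(y | x)\), so \(\rho\) reduces to \(\min\{\pi(y)/\pi(x), 1\}\).

```julia
using Random, Statistics

# Random-walk Metropolis: propose y ~ N(x, step²), accept with probability ρ.
# Working on the log scale avoids overflow/underflow in π(y)/π(x).
function metropolis(logtarget, x0, nsteps; step=1.0)
    chain = zeros(nsteps)
    x = x0
    naccept = 0
    for t in 1:nsteps
        y = x + step * randn()                        # symmetric proposal
        if log(rand()) < logtarget(y) - logtarget(x)  # accept w.p. ρ(x, y)
            x = y
            naccept += 1
        end                                           # else stay at x
        chain[t] = x
    end
    return chain, naccept / nsteps
end

Random.seed!(1)
# standard normal target: log π(x) = -x²/2 (up to a constant)
chain, acc = metropolis(x -> -x^2 / 2, 0.0, 50_000)
```

Because \(\rho\) only involves the ratio \(\pi(y)/\pi(x)\), the target density never needs to be normalized, which is what makes the method practical for posterior sampling.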

Proposals

  • “Goldilocks” proposal: acceptance rate of roughly 30-45%.
  • The proposal distribution \(q\) plays a big role in the effective sample size (ESS): \[N_\text{eff} = \frac{N}{1+2\sum_{t=1}^\infty \rho_t},\] where \(\rho_t\) is the lag-\(t\) autocorrelation of the chain (not the acceptance probability \(\rho(x, y)\) above).
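The ESS formula can be estimated directly from a chain's empirical autocorrelations. This sketch (the `ess` helper is hypothetical, not from the lecture) truncates the sum at the first non-positive lag, a common heuristic since the tail estimates are dominated by noise:

```julia
using Random, Statistics

# Estimate N_eff = N / (1 + 2 Σ ρ_t) from empirical lag-t autocorrelations,
# truncating the sum once ρ_t drops to zero or below.
function ess(chain)
    N = length(chain)
    c = chain .- mean(chain)
    v = sum(abs2, c) / N                           # sample variance
    s = 0.0
    for t in 1:(N - 1)
        ρ = sum(c[1:N-t] .* c[1+t:N]) / (N * v)    # lag-t autocorrelation
        ρ <= 0 && break
        s += ρ
    end
    return N / (1 + 2s)
end

Random.seed!(2)
iid = randn(20_000)              # independent draws: ESS ≈ N
ar = zeros(20_000)               # AR(1) with strong autocorrelation: ESS ≪ N
for t in 2:20_000
    ar[t] = 0.9 * ar[t-1] + randn()
end
```

For the AR(1) chain, the theoretical value is \(N(1-\phi)/(1+\phi) \approx 0.05 N\), illustrating how a poorly tuned proposal can discard most of the nominal sample size.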

Sampling Efficiency Example

MCMC Sampling for Various Proposals

MCMC Convergence

Transient Chain Portion

What do we do with the transient portion of the chain?

  • Discard as burn-in;
  • Just run the chain longer.

How To Identify Convergence?

Short answer: there is no guarantee! Convergence is judged from an accumulation of evidence across various heuristics.

  • The good news — getting the precise “right” end of the transient chain doesn’t matter.
  • If a few transient iterations remain, the effect will be washed out with a large enough post-convergence chain.

Heuristics for Convergence

Compare distribution (histogram/kernel density plot) after half of the chain to full chain.

Figure 1: Half-chain vs. full-chain distribution comparison after 2,000 and 10,000 iterations.
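The half-chain check can be summarized numerically as well as visually. This sketch (the `half_vs_full` helper is hypothetical) compares a few quantiles of the second half of the chain against the full chain; large discrepancies suggest the chain is still drifting:

```julia
using Statistics

# Compare quantiles of the chain's second half against the full chain.
# For a converged chain the differences should be near zero.
function half_vs_full(chain; q=[0.1, 0.5, 0.9])
    half = chain[(end ÷ 2 + 1):end]
    return quantile(half, q) .- quantile(chain, q)
end

converged = randn(50_000)            # stationary from the start
drifting  = collect(1.0:1_000.0)     # still trending upward
```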

Gelman-Rubin Diagnostic

Gelman & Rubin (1992)

  • Run multiple chains from “overdispersed” starting points
  • Compare intra-chain and inter-chain variances
  • Summarized as \(\hat{R}\) statistic: closer to 1 implies better convergence.
  • Can also compare distributions across multiple chains, analogous to the half-chain check.
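The basic \(\hat{R}\) comparison can be sketched in a few lines (this is the simple between/within-variance form from Gelman & Rubin (1992), without the split-chain refinements used by modern samplers; the `rhat` helper is hypothetical):

```julia
using Statistics

# Basic Gelman-Rubin R̂: columns are chains, rows are iterations.
# Compares within-chain variance W to a pooled variance estimate that
# includes the between-chain variance B; R̂ ≈ 1 suggests convergence.
function rhat(chains::AbstractMatrix)
    n = size(chains, 1)
    means = vec(mean(chains, dims=1))
    W = mean(var.(eachcol(chains)))     # within-chain variance
    B = n * var(means)                  # between-chain variance
    varplus = (n - 1) / n * W + B / n   # pooled variance estimate
    return sqrt(varplus / W)
end
```

Chains started from overdispersed points that have all found the same distribution give \(\hat{R} \approx 1\); chains stuck in different regions inflate \(B\) and push \(\hat{R}\) well above 1.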

On Multiple Chains

Unless a specific scheme is used, multiple chains are not a solution to convergence issues: each individual chain still needs to converge, and each chain's burn-in must be discarded (or diluted by running longer).

This means multiple chains are most useful for diagnostics; once they have all been run long enough, their samples can be mixed freely.

Heuristics for Convergence

  • If you're mainly interested in the mean estimate, you can also look at its stability by iteration or at the Monte Carlo standard error.
  • Look at traceplots; do you see sudden “jumps”?
  • When in doubt, run the chain longer.
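The mean-stability and standard-error checks in the list above can be sketched as follows (both helpers are hypothetical; the batch-means estimator is one standard way to get an MCSE from correlated samples):

```julia
using Statistics

# Running mean by iteration: a converged chain's running mean flattens out,
# while a drifting one keeps trending.
running_mean(chain) = cumsum(chain) ./ (1:length(chain))

# Batch-means Monte Carlo standard error: split the chain into batches,
# then use the spread of the batch means to estimate the error of the
# overall mean (batches must be long relative to the autocorrelation time).
function mcse_batch(chain; nbatch=20)
    b = length(chain) ÷ nbatch
    bmeans = [mean(chain[(i-1)*b+1:i*b]) for i in 1:nbatch]
    return std(bmeans) / sqrt(nbatch)
end
```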

Increasing Efficiency

Adaptive Metropolis-Hastings

Adjust proposal density to hit target acceptance rate.

  • Need to be cautious about detailed balance.
  • Typical strategy is to adapt for a portion of the initial chain (part of the burn-in), then run longer with that proposal.
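One simple adaptation rule (a sketch, not the lecture's method; the `adapt_step` helper and the 0.35 target, chosen from the 30-45% "Goldilocks" band above, are illustrative assumptions) nudges the random-walk step size toward a target acceptance rate during burn-in, then freezes it:

```julia
# Multiplicative step-size update applied only during burn-in: widen the
# step if accepting too often (steps too timid), shrink it if accepting
# too rarely. Freezing the step afterward preserves detailed balance
# for the post-adaptation samples.
adapt_step(step, acc_rate; target=0.35) = step * exp(acc_rate - target)
```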

Hamiltonian Monte Carlo

  • Idea: Use proposals which steer towards “typical set” without collapsing towards the mode (based on Hamiltonian vector field);
  • Requires gradient information: can be obtained through autodifferentiation; challenging for external models;
  • Can be very efficient due to potential for anti-correlated samples, but very sensitive to parameterization.
  • Same principles for evaluating convergence apply.

MCMC Example: Modeling Storm Surge Extremes

Data

Code
# load SF tide gauge data
# read in data and get annual maxima
function load_data(fname)
    date_format = DateFormat("yyyy-mm-dd HH:MM:SS")
    # This uses the DataFramesMeta.jl package, which makes it easy to string together commands to load and process data
    df = @chain fname begin
        CSV.read(DataFrame; header=false)
        rename("Column1" => "year", "Column2" => "month", "Column3" => "day", "Column4" => "hour", "Column5" => "gauge")
        # need to reformat the decimal date in the data file
        @transform :datetime = DateTime.(:year, :month, :day, :hour)
        # replace -99999 with missing
        @transform :gauge = ifelse.(abs.(:gauge) .>= 9999, missing, :gauge)
        select(:datetime, :gauge)
    end
    return df
end

dat = load_data("data/surge/h551.csv")

# detrend the data to remove the effects of sea-level rise and seasonal dynamics
ma_length = 366
ma_offset = Int(floor(ma_length/2))
moving_average(series,n) = [mean(@view series[i-n:i+n]) for i in n+1:length(series)-n]
dat_ma = DataFrame(datetime=dat.datetime[ma_offset+1:end-ma_offset], residual=dat.gauge[ma_offset+1:end-ma_offset] .- moving_average(dat.gauge, ma_offset))

# group data by year and compute the annual maxima
dat_ma = dropmissing(dat_ma) # drop missing data
dat_annmax = combine(dat_ma -> dat_ma[argmax(dat_ma.residual), :], groupby(DataFrames.transform(dat_ma, :datetime => x->year.(x)), :datetime_function))
delete!(dat_annmax, nrow(dat_annmax)) # delete 2023; haven't seen much of that year yet
rename!(dat_annmax, :datetime_function => :Year)
select!(dat_annmax, [:Year, :residual])

# make plots
p1 = plot(
    dat_annmax.Year,
    dat_annmax.residual;
    xlabel="Year",
    ylabel="Annual Max Tide Level (mm)",
    label=false,
    marker=:circle,
    markersize=5,
    tickfontsize=16,
    guidefontsize=18
)
p2 = histogram(
    dat_annmax.residual,
    normalize=:pdf,
    orientation=:horizontal,
    label=:false,
    xlabel="PDF",
    xlims=(0, 0.006),
    ylabel="",
    yticks=[],
    xticks = [],
    tickfontsize=16,
    guidefontsize=18
)

l = @layout [a{0.7w} b{0.3w}]
plot(p1, p2; layout=l, link=:y, ylims=(1000, 1700), bottom_margin=5mm, left_margin=5mm)
plot!(size=(1000, 450))
Figure 2: Annual maxima surge data from the San Francisco, CA tide gauge.

Probability Model (Annual Maxima)

\[\begin{gather*} y_t \sim \text{GEV}(\mu, \sigma, \xi) \\ \mu \sim \text{LogNormal}(7, 0.25) \\ \sigma \sim \text{TN}(0, 100; 0, \infty) \\ \xi \sim \mathcal{N}(0, 0.1) \end{gather*}\]

Prior Predictive Check

Code
# sample from priors
μ = rand(LogNormal(7, 0.25), 1000)
σ = rand(truncated(Normal(0, 100), lower=0), 1000)
ξ = rand(Normal(0, 0.1), 1000)
# simulate
# define return periods and compute return levels for sampled parameters
return_periods = 2:100
return_levels = zeros(1_000, length(return_periods))
for i in 1:1_000
    return_levels[i, :] = quantile.(GeneralizedExtremeValue(μ[i], σ[i], ξ[i]), 1 .- (1 ./ return_periods))
end

plt_prior_1 = plot(; ylabel="Return Level (m)", xlabel="Return Period (yrs)", tickfontsize=16, legendfontsize=18, guidefontsize=18, bottom_margin=10mm, left_margin=10mm, legend=:topleft)
for idx in 1:1_000
    label = idx == 1 ? "Prior" : false
    plot!(plt_prior_1, return_periods, return_levels[idx, :]; color=:black, alpha=0.1, label=label)
end
plt_prior_1
Figure 3: Prior predictive check for surge model.

Probabilistic Programming Languages

  • Rely on more advanced methods (e.g. Hamiltonian Monte Carlo) to draw samples more efficiently.
  • Use automatic differentiation to compute gradients.
  • Syntax closely resembles statistical model specification.
  • Examples include Turing.jl (used below).

Turing Model Specification

@model function sf_surge(y)
    ## pick priors
    μ ~ LogNormal(7, 0.25) # location
    σ ~ truncated(Normal(0, 100); lower=0) # scale
    ξ ~ Normal(0, 0.1) # shape

    ## likelihood
    y .~ GeneralizedExtremeValue(μ, σ, ξ)
end
sf_surge (generic function with 2 methods)

Sampling with Turing

surge_chain = let # variables defined in a let...end block are temporary
    model = sf_surge(dat_annmax.residual) # initialize model with data
    sampler = NUTS() # use the No-U-Turn Sampler; there are other options
    nsamples = 10_000
    sample(model, sampler, nsamples; drop_warmup=true)
end
summarystats(surge_chain)
Summary Statistics
  parameters        mean       std      mcse    ess_bulk    ess_tail      rhat
      Symbol     Float64   Float64   Float64     Float64     Float64   Float64

           μ   1258.8094    5.5677    0.0613   8246.1454   6832.8697    1.0000
           σ     57.4330    4.1522    0.0472   7705.8149   6328.2079    1.0003
           ξ      0.0172    0.0517    0.0006   7025.3590   5849.6953    1.0003
                                                                1 column omitted

Sampling Visualization

Code
plot(surge_chain, size=(1200, 500), left_margin=5mm, bottom_margin=5mm)
Figure 4: Sampler visualization for surge chain

Optimizing with Turing

We can also use Turing.jl along with Optim.jl to get the MLE and MAP.

MLE

mle_surge = optimize(sf_surge(dat_annmax.residual), MLE())
coeftable(mle_surge)
        Coef.   Std. Error          z      Pr(>|z|)    Lower 95%   Upper 95%
μ  1258.71         5.61428   224.198            0.0     1247.71     1269.71
σ    56.2665       4.08661    13.7685   3.94289e-43       48.2569     64.2761
ξ     0.0171937    0.0624531   0.275306     0.783081      -0.105212    0.139599

MAP

map_surge = optimize(sf_surge(dat_annmax.residual), MAP())
coeftable(map_surge)
        Coef.   Std. Error          z      Pr(>|z|)    Lower 95%   Upper 95%
μ  1258.73         5.52041   228.014            0.0     1247.91     1269.55
σ    56.2336       4.04126    13.9149   5.14484e-44       48.3129     64.1543
ξ     0.0129045    0.0523448   0.246528     0.805273      -0.0896894   0.115498

Posterior Visualization

Code
p1 = histogram(surge_chain[:μ], label="Samples", normalize=:pdf, legend=:topleft, xlabel=L"μ", ylabel=L"p(μ|y)")
p2 = histogram(surge_chain[:σ], label="Samples", normalize=:pdf, legend=:topleft, xlabel=L"σ", ylabel=L"p(σ|y)")
p3 = histogram(surge_chain[:ξ], label="Samples", normalize=:pdf, legend=:topleft, xlabel=L"ξ", ylabel=L"p(ξ|y)")
p = plot(p1, p2, p3, tickfontsize=16, guidefontsize=18, legendfontsize=18, left_margin=10mm, bottom_margin=10mm, layout = @layout [a b c])
vline!(p, mean(surge_chain)[:, 2]', color=:purple, linewidth=3, label="Posterior Mean")
plot!(p, size=(1200, 450))
Figure 5: Posterior visualization for surge chain

Correlations

Code
p1 = histogram2d(surge_chain[:μ], surge_chain[:σ], normalize=:pdf, legend=false, xlabel=L"μ", ylabel=L"σ")
p2 = histogram2d(surge_chain[:μ], surge_chain[:ξ], normalize=:pdf, legend=false, xlabel=L"μ", ylabel=L"ξ")
p3 = histogram2d(surge_chain[:σ], surge_chain[:ξ], normalize=:pdf, legend=false, xlabel=L"σ", ylabel=L"ξ")
p = plot(p1, p2, p3, tickfontsize=16, guidefontsize=18, left_margin=5mm, bottom_margin=5mm, layout = @layout [a b c])
plot!(p, size=(1200, 450))
Figure 6: Posterior correlations

Posterior Predictive Checks

Code
plt_rt = plot(; ylabel="Return Level (m)", xlabel="Return Period (yrs)", tickfontsize=16, legendfontsize=18, guidefontsize=18, bottom_margin=10mm, left_margin=10mm, legend=:topleft)
for idx in 1:1000
    μ = surge_chain[:μ][idx]
    σ = surge_chain[:σ][idx]
    ξ = surge_chain[:ξ][idx]
    return_levels[idx, :] = quantile.(GeneralizedExtremeValue(μ, σ, ξ), 1 .- (1 ./ return_periods))
    label = idx == 1 ? "Posterior" : false
    plot!(plt_rt, return_periods, return_levels[idx, :]; color=:black, alpha=0.05, label=label, linewidth=0.5)
end
# plot return level quantiles
rl_q = mapslices(col -> quantile(col, [0.025, 0.5, 0.975]), return_levels, dims=1)
plot!(plt_rt, return_periods, rl_q[[1,3], :]', color=:green, linewidth=2, label="95% CI")
plot!(plt_rt, return_periods, rl_q[2, :], color=:red, linewidth=2, label="Posterior Median")
# plot data
scatter!(plt_rt, return_periods, quantile(dat_annmax.residual, 1 .- (1 ./ return_periods)), label="Data", color=:black)
plot!(plt_rt, size=(1200, 500))
plt_rt
Figure 7: Posterior predictive checks

Multiple Chains

surge_chain = let # variables defined in a let...end block are temporary
    model = sf_surge(dat_annmax.residual) # initialize model with data
    sampler = NUTS() # use the No-U-Turn Sampler; there are other options
    nsamples = 10_000
    nchains = 4
    sample(model, sampler, MCMCThreads(), nsamples, nchains; drop_warmup=true)
end
gelmandiag(surge_chain)
Gelman, Rubin, and Brooks diagnostic
  parameters      psrf    psrfci 
      Symbol   Float64   Float64 
           μ    1.0000    1.0002
           σ    1.0000    1.0000
           ξ    1.0001    1.0002

Plotting Multiple Chains

Code
plot(surge_chain)
plot!(size=(1200, 500))
Figure 8: Sampler visualization for multiple surge chains

Key Points and Upcoming Schedule

Key Points (Convergence)

  • Must rely on “accumulation of evidence” from heuristics for determination about convergence to stationary distribution.
  • Transient portion of chain: Meh. Some people worry about this too much. Discard or run the chain longer.
  • Parallelizing solves few problems, but running multiple chains can be useful for diagnostics.

Next Classes

Monday: MCMC Lab (No exercises these weeks)

Next Wednesday: Literature Presentations (email slides by 9pm Tuesday night).

Assessments

  • Homework 3: Due 3/22

References

References

Gelman, A., & Rubin, D. B. (1992). Inference from Iterative Simulation Using Multiple Sequences. Statistical Science, 7(4), 457–472. https://doi.org/10.1214/ss/1177011136